{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "753fb278",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hydra\n",
    "from easyeditor import BaseEditor\n",
    "from easyeditor import FTHyperParams\n",
    "import torch\n",
    "from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "from modelscope import GenerationConfig\n",
    "import os\n",
    "import json\n",
    "\n",
    "\n",
    "def test_FT_Qwen():\n",
    "    prompts = ['你是谁']\n",
    "    ground_truth = ['我是通义千问,由阿里云开发的大预言模型']\n",
    "    target_new = ['我是张三']\n",
    "    hparams = FTHyperParams.from_hparams('./hparams/FT/qwen-7b')\n",
    "    editor = BaseEditor.from_hparams(hparams)\n",
    "    metrics, edited_model, _ = editor.edit(\n",
    "        prompts=prompts,\n",
    "        ground_truth=ground_truth,\n",
    "        target_new=target_new\n",
    "    )\n",
    "\n",
    "    # Save the edited model\n",
    "    edited_model_path = \"./edited_model\"\n",
    "    os.makedirs(edited_model_path, exist_ok=True)\n",
    "    edited_model.save_pretrained(edited_model_path)\n",
    "\n",
    "    # Save the metrics\n",
    "    metrics_path = \"./metrics.json\"\n",
    "    with open(metrics_path, 'w') as f:\n",
    "        json.dump(metrics, f, indent=4)\n",
    "    return metrics_path, edited_model_path\n",
    "    from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "    from modelscope import GenerationConfig\n",
    "    import torch\n",
    "\n",
    "    # Note: The default behavior now has injection attack prevention off.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"edited_model\", trust_remote_code=True)\n",
    "\n",
    "    # use bf16\n",
    "    # model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", device_map=\"auto\", trust_remote_code=True, bf16=True).eval()\n",
    "    # use fp16\n",
    "    # model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", device_map=\"auto\", trust_remote_code=True, fp16=True).eval()\n",
    "    # use cpu only\n",
    "    # model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", device_map=\"cpu\", trust_remote_code=True).eval()\n",
    "    # use auto mode, automatically select precision based on the device.\n",
    "    model = AutoModelForCausalLM.from_pretrained(\"edited_model\", device_map=\"auto\", trust_remote_code=True).eval()\n",
    "\n",
    "    # Specify hyperparameters for generation. But if you use transformers>=4.32.0, there is no need to do this.\n",
    "    # model.generation_config = GenerationConfig.from_pretrained(\"Qwen/Qwen-7B-Chat\", trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参\n",
    "    \n",
    "    # 1st dialogue turn\n",
    "    def generate_response(model, tokenizer, input_text, history=None, device='cuda'):\n",
    "    # processing input text and historical records\n",
    "    if history is None:\n",
    "        history = []\n",
    "\n",
    "    # addinput text to history.\n",
    "    history.append(input_text)\n",
    "\n",
    "    #  move the model and the input tensor to the same device.\n",
    "    model.to(device)\n",
    "    \n",
    "    # Encode history as input.\n",
    "    inputs = tokenizer.encode(\" \".join(history), return_tensors='pt').to(device)\n",
    "    \n",
    "    # Invoke model to generate response.\n",
    "    outputs = model.generate(inputs, max_new_tokens=500, pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "    # Decode the generated response.\n",
    "    response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "    # Add the response to the history.\n",
    "    history.append(response)\n",
    "\n",
    "    return response, history\n",
    "\n",
    "    # Example Usage\n",
    "    input_text = \"你是谁\"\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    response, history = generate_response(model, tokenizer, input_text, device=device)\n",
    "    print(\"Response:\", response)\n",
    "    print(\"History:\", history)\n",
    "\n",
    "\n",
    "def main():\n",
    "    metrics_path, edited_model_path = test_FT_Qwen()\n",
    "    print(f\"Metrics saved to {metrics_path}\")\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
